{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "IcwwhIrLwMZc",
      "metadata": {
        "id": "IcwwhIrLwMZc"
      },
      "source": [
        "Copyright 2023 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "258ca2e3",
      "metadata": {
        "id": "258ca2e3"
      },
      "outputs": [],
      "source": [
        "from diffusers import StableDiffusionPipeline\n",
        "from diffusers.schedulers import LMSDiscreteScheduler\n",
        "from random import randrange\n",
        "import torch\n",
        "from diffusers import StableDiffusionPipeline\n",
        "from diffusers.schedulers import LMSDiscreteScheduler, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler\n",
        "import timm\n",
        "import torchvision.transforms as transforms\n",
        "from random import randrange\n",
        "import requests\n",
        "import torch.optim as optim\n",
        "import numpy as np\n",
        "from PIL import Image, ImageDraw\n",
        "import numpy as np\n",
        "import argparse\n",
        "import os\n",
        "import glob\n",
        "from pathlib import Path"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f46b8c8c",
      "metadata": {
        "id": "f46b8c8c"
      },
      "outputs": [],
      "source": [
        "# initialize stable diffusion pipeline\n",
        "pipe = StableDiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\")\n",
        "pipe.to(\"cuda\")\n",
        "scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)\n",
        "pipe.scheduler = scheduler\n",
        "orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()\n",
        "pipe.text_encoder.text_model.embeddings.token_embedding.weight.requires_grad_(False)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "14ba3600",
      "metadata": {
        "id": "14ba3600"
      },
      "source": [
        "# Load decomposition results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "690fa1ba",
      "metadata": {
        "id": "690fa1ba"
      },
      "outputs": [],
      "source": [
        "concept = 'dog'\n",
        "folder = f'./{concept}'\n",
        "# load coefficients\n",
        "alphas_dict = torch.load(f'{folder}/best_alphas.pt').detach_().requires_grad_(False)\n",
        "# load vocabulary\n",
        "dictionary = torch.load(f'{folder}/dictionary.pt')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4a77cc1a",
      "metadata": {
        "id": "4a77cc1a"
      },
      "source": [
        "# Visualize top coefficients and top tokens"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2cef1f3c",
      "metadata": {
        "id": "2cef1f3c"
      },
      "outputs": [],
      "source": [
        "sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)\n",
        "num_indices=10\n",
        "top_indices_orig_dict = [dictionary[i] for i in sorted_indices[:num_indices]]\n",
        "print(\"top coefficients: \", sorted_alphas[:num_indices].cpu().numpy())\n",
        "alpha_ids = [pipe.tokenizer.decode(idx) for idx in top_indices_orig_dict]\n",
        "print(\"top tokens: \", alpha_ids)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bdac204b",
      "metadata": {
        "id": "bdac204b"
      },
      "source": [
        "# Extract top 50 tokens"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f3dbfc77",
      "metadata": {
        "id": "f3dbfc77"
      },
      "outputs": [],
      "source": [
        "num_tokens = 50\n",
        "alphas = torch.zeros(orig_embeddings.shape[0]).cuda()\n",
        "sorted_alphas, sorted_indices = torch.sort(alphas_dict.abs(), descending=True)\n",
        "top_word_idx = [dictionary[i] for i in sorted_indices[:num_tokens]]\n",
        "for i,index in enumerate(top_word_idx):\n",
        "    alphas[index] = alphas_dict[sorted_indices[i]]\n",
        "\n",
        "# add placeholder for w^*\n",
        "placeholder_token = '\u003c\u003e'\n",
        "pipe.tokenizer.add_tokens(placeholder_token)\n",
        "placeholder_token_id = pipe.tokenizer.convert_tokens_to_ids(placeholder_token)\n",
        "pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))\n",
        "token_embeds = pipe.text_encoder.get_input_embeddings().weight.detach().requires_grad_(False)\n",
        "\n",
        "# compute w^* and normalize its embedding\n",
        "learned_embedding = torch.matmul(alphas, orig_embeddings).flatten()\n",
        "norms = [i.norm().item() for i in orig_embeddings]\n",
        "avg_norm = np.mean(norms)\n",
        "learned_embedding /= learned_embedding.norm()\n",
        "learned_embedding *= avg_norm\n",
        "\n",
        "# add w^* to vocabulary\n",
        "token_embeds[placeholder_token_id] = torch.nn.Parameter(learned_embedding)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b050e78f",
      "metadata": {
        "id": "b050e78f"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "def get_image_grid(images) -\u003e Image:\n",
        "    num_images = len(images)\n",
        "    cols = int(math.ceil(math.sqrt(num_images)))\n",
        "    rows = int(math.ceil(num_images / cols))\n",
        "    width, height = images[0].size\n",
        "    grid_image = Image.new('RGB', (cols * width, rows * height))\n",
        "    for i, img in enumerate(images):\n",
        "        x = i % cols\n",
        "        y = i // cols\n",
        "        grid_image.paste(img, (x * width, y * height))\n",
        "    return grid_image"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8a24b91a",
      "metadata": {
        "id": "8a24b91a"
      },
      "source": [
        "# Reconstruction results- first 6 images of seed 0 (no cherry picking)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5ab6a7a7",
      "metadata": {
        "id": "5ab6a7a7",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "prompt = 'a photo of a \u003c\u003e'\n",
        "\n",
        "generator = torch.Generator(\"cuda\").manual_seed(0)\n",
        "image = pipe(prompt,\n",
        "             guidance_scale=7.5,\n",
        "             generator=generator,\n",
        "             return_dict=False,\n",
        "             num_images_per_prompt=6,\n",
        "            num_inference_steps=50)\n",
        "display(get_image_grid(image[0]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dcfa03da",
      "metadata": {
        "id": "dcfa03da",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "prompt = 'a photo of a dog'\n",
        "generator = torch.Generator(\"cuda\").manual_seed(0)\n",
        "image = pipe(prompt,\n",
        "             guidance_scale=7.5,\n",
        "             generator=generator,\n",
        "             return_dict=False,\n",
        "             num_images_per_prompt=6,\n",
        "            num_inference_steps=50)\n",
        "display(get_image_grid(image[0]))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
